package com.example.config;

import lombok.extern.slf4j.Slf4j;
import org.springframework.data.redis.core.RedisTemplate;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;

/**
 * 限流服务 - 支持Redis和内存两种模式
 */
@Slf4j
public class RateLimitService {

    private final RedisTemplate<String, Object> redisTemplate;
    private final Map<String, RateLimitEntry> memoryCache;
    private final boolean useRedis;

    // 内存缓存条目
    private static class RateLimitEntry {
        int count;
        long expireTime;

        RateLimitEntry(int count, long expireTime) {
            this.count = count;
            this.expireTime = expireTime;
        }
    }

    // Redis模式构造函数
    public RateLimitService(RedisTemplate<String, Object> redisTemplate) {
        this.redisTemplate = redisTemplate;
        this.memoryCache = null;
        this.useRedis = true;
        log.info("限流服务启动 - 使用Redis模式");
    }

    // 内存模式构造函数
    public RateLimitService(Map<String, Object> unused) {
        this.redisTemplate = null;
        this.memoryCache = new ConcurrentHashMap<>();
        this.useRedis = false;
        log.info("限流服务启动 - 使用内存模式");
    }

    /**
     * 检查是否超过限流
     * @param key 限流键
     * @param limit 限制次数
     * @param window 时间窗口（秒）
     * @return true表示允许访问，false表示被限流
     */
    public boolean isAllowed(String key, int limit, int window) {
        if (useRedis) {
            return isAllowedRedis(key, limit, window);
        } else {
            return isAllowedMemory(key, limit, window);
        }
    }

    private boolean isAllowedRedis(String key, int limit, int window) {
        try {
            String redisKey = "rate_limit:" + key;

            // 获取当前计数
            Integer count = (Integer) redisTemplate.opsForValue().get(redisKey);

            if (count == null) {
                // 第一次访问，设置计数为1
                redisTemplate.opsForValue().set(redisKey, 1, window, TimeUnit.SECONDS);
                return true;
            } else if (count < limit) {
                // 未超过限制，计数加1
                redisTemplate.opsForValue().increment(redisKey);
                return true;
            } else {
                // 超过限制
                log.warn("限流触发: key={}, count={}, limit={}", key, count, limit);
                return false;
            }
        } catch (Exception e) {
            log.error("限流检查异常: {}", e.getMessage(), e);
            // 异常情况下允许访问，避免影响正常业务
            return true;
        }
    }

    private boolean isAllowedMemory(String key, int limit, int window) {
        try {
            long currentTime = System.currentTimeMillis();
            long expireTime = currentTime + (window * 1000L);

            RateLimitEntry entry = memoryCache.get(key);

            if (entry == null || currentTime > entry.expireTime) {
                // 第一次访问或已过期，重新计数
                memoryCache.put(key, new RateLimitEntry(1, expireTime));
                return true;
            } else if (entry.count < limit) {
                // 未超过限制，计数加1
                entry.count++;
                return true;
            } else {
                // 超过限制
                log.warn("限流触发: key={}, count={}, limit={}", key, entry.count, limit);
                return false;
            }
        } catch (Exception e) {
            log.error("限流检查异常: {}", e.getMessage(), e);
            // 异常情况下允许访问，避免影响正常业务
            return true;
        }
    }

    /**
     * 登录限流检查
     */
    public boolean checkLoginLimit(String ip) {
        return isAllowed("login:" + ip, 5, 300); // 5分钟内最多5次登录尝试
    }

    /**
     * API调用限流检查
     */
    public boolean checkApiLimit(String userId) {
        return isAllowed("api:" + userId, 1000, 3600); // 1小时内最多1000次API调用
    }

    /**
     * 消息发送限流检查
     */
    public boolean checkMessageLimit(String userId) {
        return isAllowed("message:" + userId, 100, 3600); // 1小时内最多100条消息
    }

    /**
     * 文件上传限流检查
     */
    public boolean checkUploadLimit(String userId) {
        return isAllowed("upload:" + userId, 50, 3600); // 1小时内最多50次上传
    }

    /**
     * 获取剩余次数
     */
    public int getRemainingCount(String key, int limit) {
        try {
            String redisKey = "rate_limit:" + key;
            Integer count = (Integer) redisTemplate.opsForValue().get(redisKey);
            return count == null ? limit : Math.max(0, limit - count);
        } catch (Exception e) {
            log.error("获取剩余次数异常: {}", e.getMessage(), e);
            return limit;
        }
    }

    /**
     * 重置限流计数
     */
    public void resetLimit(String key) {
        try {
            String redisKey = "rate_limit:" + key;
            redisTemplate.delete(redisKey);
            log.info("重置限流计数: {}", key);
        } catch (Exception e) {
            log.error("重置限流计数异常: {}", e.getMessage(), e);
        }
    }
}
